{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vjAC2mZnb4nz"
   },
   "source": [
    "# Image transformations\n",
    "\n",
    "This notebook shows new features of torchvision image transformations. \n",
    "\n",
    "Prior to v0.8.0, transforms in torchvision have traditionally been PIL-centric and presented multiple limitations due to that. Now, since v0.8.0, transforms implementations are Tensor and PIL compatible and we can achieve the following new \n",
    "features:\n",
    "- transform multi-band torch tensor images (with more than 3-4 channels) \n",
    "- torchscript transforms together with your model for deployment\n",
    "- support for GPU acceleration\n",
    "- batched transformation such as for videos\n",
    "- read and decode data directly as torch tensor with torchscript support (for PNG and JPEG image formats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "btaDWPDbgIyW",
    "outputId": "8a83d408-f643-42da-d247-faf3a1bd3ae0"
   },
   "outputs": [],
   "source": [
    "import torch, torchvision\n",
    "torch.__version__, torchvision.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9Vj9draNb4oA"
   },
   "source": [
    "## Transforms on CPU/CUDA tensor images\n",
    "\n",
    "Let's show how to apply transformations on images opened directly as a torch tensors.\n",
    "Now, torchvision provides image reading functions for PNG and JPG images with torchscript support. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Epp3hCy0b4oD"
   },
   "outputs": [],
   "source": [
    "from torchvision.datasets.utils import download_url\n",
    "\n",
    "download_url(\"https://farm1.static.flickr.com/152/434505223_8d1890e1e2.jpg\", \".\", \"test-image.jpg\")\n",
    "download_url(\"https://farm3.static.flickr.com/2142/1896267403_24939864ba.jpg\", \".\", \"test-image2.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y-m7lYDPb4oK"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 303
    },
    "id": "5bi8Q7L3b4oc",
    "outputId": "e5de5c73-e16d-4992-ebee-94c7ddf0bf54"
   },
   "outputs": [],
   "source": [
    "from torchvision.io.image import read_image\n",
    "\n",
    "tensor_image = read_image(\"test-image.jpg\")\n",
    "\n",
    "print(\"tensor image info: \", tensor_image.shape, tensor_image.dtype)\n",
    "\n",
    "plt.imshow(tensor_image.numpy().transpose((1, 2, 0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_rgb_image(tensor):\n",
    "    \"\"\"Helper method to get RGB numpy array for plotting\"\"\"\n",
    "    np_img = tensor.cpu().numpy().transpose((1, 2, 0))\n",
    "    m1, m2 = np_img.min(axis=(0, 1)), np_img.max(axis=(0, 1))\n",
    "    return (255.0 * (np_img - m1) / (m2 - m1)).astype(\"uint8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 322
    },
    "id": "PgWpjxQ3b4pF",
    "outputId": "e9a138e8-b45c-4f75-d849-3b41de0e5472"
   },
   "outputs": [],
   "source": [
    "import torchvision.transforms as T\n",
    "\n",
    "# to fix random seed is now:\n",
    "torch.manual_seed(12)\n",
    "\n",
    "transforms = T.Compose([\n",
    "    T.RandomCrop(224),\n",
    "    T.RandomHorizontalFlip(p=0.3),\n",
    "    T.ConvertImageDtype(torch.float),\n",
    "    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "out_image = transforms(tensor_image)\n",
    "print(\"output tensor image info: \", out_image.shape, out_image.dtype)\n",
    "\n",
    "plt.imshow(to_rgb_image(out_image))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LmYQB4cxb4pI"
   },
   "source": [
    "Tensor images can be on GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 322
    },
    "id": "S6syYJGEb4pN",
    "outputId": "86bddb64-e648-45f2-c216-790d43cfc26d"
   },
   "outputs": [],
   "source": [
    "out_image = transforms(tensor_image.to(\"cuda\"))\n",
    "print(\"output tensor image info: \", out_image.shape, out_image.dtype, out_image.device)\n",
    "\n",
    "plt.imshow(to_rgb_image(out_image))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jg9TQd7ajfyn"
   },
   "source": [
    "## Scriptable transforms for easier deployment via torchscript\n",
    "\n",
    "Next, we show how to combine input transformations and model's forward pass and use `torch.jit.script` to obtain a single scripted module.\n",
    "\n",
    "**Note:** we have to use only scriptable transformations that should be derived from `torch.nn.Module`. \n",
    "Since v0.8.0, all transformations are scriptable except `Compose`, `RandomChoice`, `RandomOrder`, `Lambda` and those applied on PIL images.\n",
    "The transformations like `Compose` are kept for backward compatibility and can be easily replaced by existing torch modules, like `nn.Sequential`.\n",
    "\n",
    "Let's define a module `Predictor` that transforms input tensor and applies ImageNet pretrained resnet18 model on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NSDOJ3RajfvO"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as T\n",
    "from torchvision.io.image import read_image\n",
    "from torchvision.models import resnet18\n",
    "\n",
    "\n",
    "class Predictor(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.resnet18 = resnet18(pretrained=True).eval()\n",
    "        self.transforms = nn.Sequential(\n",
    "            T.Resize([256, ]), # We use single int value inside a list due to torchscript type restrictions\n",
    "            T.CenterCrop(224),\n",
    "            T.ConvertImageDtype(torch.float),\n",
    "            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        )\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        with torch.no_grad():\n",
    "            x = self.transforms(x)\n",
    "            y_pred = self.resnet18(x)\n",
    "            return y_pred.argmax(dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZZKDovqej5vA"
   },
   "source": [
    "Now, let's define scripted and non-scripted instances of `Predictor` and apply on multiple tensor images of the same size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GBBMSo7vjfr0"
   },
   "outputs": [],
   "source": [
    "from torchvision.io.image import read_image\n",
    "\n",
    "predictor = Predictor().to(\"cuda\")\n",
    "scripted_predictor = torch.jit.script(predictor).to(\"cuda\")\n",
    "\n",
    "\n",
    "tensor_image1 = read_image(\"test-image.jpg\")\n",
    "tensor_image2 = read_image(\"test-image2.jpg\")\n",
    "batch = torch.stack([tensor_image1[:, -320:, :], tensor_image2[:, -320:, :]]).to(\"cuda\")\n",
    "\n",
    "res1 = scripted_predictor(batch)\n",
    "res2 = predictor(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 501
    },
    "id": "Dmi9r_p-oKsk",
    "outputId": "b9c55e7d-5db1-4975-c485-fecc4075bf47"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from torchvision.datasets.utils import download_url\n",
    "\n",
    "\n",
    "download_url(\"https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json\", \".\", \"imagenet_class_index.json\")\n",
    "\n",
    "\n",
    "with open(\"imagenet_class_index.json\", \"r\") as h:\n",
    "    labels = json.load(h)\n",
    "\n",
    "\n",
    "plt.figure(figsize=(12, 7))\n",
    "for i, p in enumerate(res1):\n",
    "    plt.subplot(1, 2, i + 1)\n",
    "    plt.title(\"Scripted predictor:\\n{label})\".format(label=labels[str(p.item())]))\n",
    "    plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))\n",
    "\n",
    "\n",
    "plt.figure(figsize=(12, 7))\n",
    "for i, p in enumerate(res2):\n",
    "    plt.subplot(1, 2, i + 1)\n",
    "    plt.title(\"Original predictor:\\n{label})\".format(label=labels[str(p.item())]))\n",
    "    plt.imshow(batch[i, ...].cpu().numpy().transpose((1, 2, 0)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7IYsjzpFqcK8"
   },
   "source": [
    "We save and reload scripted predictor in Python or C++ and use it for inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 52
    },
    "id": "0kk9LLw5jfol",
    "outputId": "05ea6db7-7fcf-4b74-a763-5f117c14cc00"
   },
   "outputs": [],
   "source": [
    "scripted_predictor.save(\"scripted_predictor.pt\")\n",
    "\n",
    "scripted_predictor = torch.jit.load(\"scripted_predictor.pt\")\n",
    "res1 = scripted_predictor(batch)\n",
    "\n",
    "for i, p in enumerate(res1):\n",
    "    print(\"Scripted predictor: {label})\".format(label=labels[str(p.item())]))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data reading and decoding functions also support torch script and therefore can be part of the model as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AnotherPredictor(Predictor):\n",
    "\n",
    "    def forward(self, path: str) -> int:\n",
    "        with torch.no_grad():\n",
    "            x = read_image(path).unsqueeze(0)\n",
    "            x = self.transforms(x)\n",
    "            y_pred = self.resnet18(x)\n",
    "            return int(y_pred.argmax(dim=1).item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-cMwTs3Yjffy"
   },
   "outputs": [],
   "source": [
    "scripted_predictor2 = torch.jit.script(AnotherPredictor())\n",
    "\n",
    "res = scripted_predictor2(\"test-image.jpg\")\n",
    "\n",
    "print(\"Scripted another predictor: {label})\".format(label=labels[str(res)]))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "torchvision_scriptable_transforms.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
